! (C) Copyright 2022- NVIDIA.
!
! This software is licensed under the terms of the Apache Licence Version 2.0
! which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
! In applying this licence, ECMWF does not waive the privileges and immunities
! granted to it by virtue of its status as an intergovernmental organisation
! nor does it submit to any jurisdiction.
module openacc_ext_type
  use iso_c_binding, only: c_size_t
  implicit none
  private

  !> Byte range occupied by a host array, as built by ext_acc_pass in
  !> module openacc_ext:
  !>   ptr - start address of the range, stored as an integer
  !>   sz  - length of the range in bytes
  !> It lives in this separate module because, per the original author's
  !> note, it could not be made part of openacc_ext itself.
  type, public :: ext_acc_arr_desc
    integer(c_size_t) :: ptr
    integer(c_size_t) :: sz
  end type ext_acc_arr_desc
end module openacc_ext_type
module openacc_ext
  !! Maps several host arrays with as few OpenACC data calls as possible.
  !! Each array is reduced to the byte range it spans (ext_acc_pass).
  !! Overlapping or touching ranges are merged, and one acc_create /
  !! acc_copyin / acc_copyout / acc_delete is issued per merged range.
  !! The span of a strided section runs from its lowest to its highest
  !! element, so the gaps in between are mapped as well.
  use iso_c_binding, only: c_ptr, c_size_t, c_loc, c_f_pointer
  use iso_fortran_env, only: int32
  use openacc, only: acc_create, acc_copyin, acc_copyout, acc_delete, acc_handle_kind
  use openacc_ext_type, only: ext_acc_arr_desc
  implicit none

  private
  public :: ext_acc_pass, ext_acc_create, ext_acc_copyin, ext_acc_copyout, &
    & ext_acc_delete, ext_acc_arr_desc, acc_handle_kind

  ! Selectors for the OpenACC data operation applied by map_ranges.
  integer, parameter :: op_create = 1, op_copyin = 2, op_copyout = 3, op_delete = 4

  ! Merged ranges are handed to the runtime as int32 arrays; this is the
  ! element size used to turn a byte count into an array extent.
  integer(c_size_t), parameter :: word_bytes = bit_size(0_int32) / 8

  ! A merged byte range in the form consumed by the OpenACC calls.
  type common_pointer_descr
    type(c_ptr) :: ptr
    integer(c_size_t) :: sz
  end type common_pointer_descr

  interface ext_acc_pass
    module procedure ext_acc_pass_2d_r4, ext_acc_pass_3d_r4, ext_acc_pass_4d_r4, &
      & ext_acc_pass_2d_r8, ext_acc_pass_3d_r8, ext_acc_pass_4d_r8
  end interface
contains

  !> Integer value of a C address (the representation used by ext_acc_arr_desc%ptr).
  !> Assumes c_ptr and integer(c_size_t) have the same size, as the
  !> transfer-based conversion requires.
  pure function addr_of(p) result(a)
    type(c_ptr), intent(in) :: p
    integer(c_size_t) :: a
    a = transfer(p, 0_c_size_t)
  end function addr_of

  !> Byte range spanned by an array.
  !>   first     - address of the element with all indices at their lower bound
  !>   ends(d)   - address of the element at the upper bound of dimension d and
  !>               the lower bound of every other dimension
  !>   elem_bits - storage size of one element in bits
  !> Offsets along each dimension may be negative (reversed sections), so the
  !> lowest address is first plus all negative offsets. The length is the sum
  !> of absolute offsets plus one element.
  pure function span_desc(first, ends, elem_bits) result(ret)
    integer(c_size_t), intent(in) :: first, ends(:)
    integer, intent(in) :: elem_bits
    type(ext_acc_arr_desc) :: ret
    integer(c_size_t) :: offsets(size(ends))

    offsets = ends - first
    ret%ptr = first + sum(min(offsets, 0_c_size_t))
    ret%sz = sum(abs(offsets)) + elem_bits / 8
  end function span_desc

  ! The ext_acc_pass_* specifics below return the byte range covered by
  ! `arr`, which may be contiguous or a strided section. Assumed-shape dummies
  ! have lower bound 1 in every dimension. A zero-sized array yields
  ! ptr = 0, sz = 0, and such a range is ignored when ranges are merged.
  ! `target` is required because c_loc is applied to elements of `arr`.

  function ext_acc_pass_2d_r4(arr) result(ret)
    real(4), intent(in), target :: arr(:,:)
    type(ext_acc_arr_desc) :: ret
    integer :: n(2)

    ret = ext_acc_arr_desc(0_c_size_t, 0_c_size_t)
    if (size(arr) == 0) return
    n = shape(arr)
    ret = span_desc(addr_of(c_loc(arr(1, 1))), &
      & [addr_of(c_loc(arr(n(1), 1))), &
      &  addr_of(c_loc(arr(1, n(2))))], storage_size(arr))
  end function ext_acc_pass_2d_r4

  function ext_acc_pass_3d_r4(arr) result(ret)
    real(4), intent(in), target :: arr(:,:,:)
    type(ext_acc_arr_desc) :: ret
    integer :: n(3)

    ret = ext_acc_arr_desc(0_c_size_t, 0_c_size_t)
    if (size(arr) == 0) return
    n = shape(arr)
    ret = span_desc(addr_of(c_loc(arr(1, 1, 1))), &
      & [addr_of(c_loc(arr(n(1), 1, 1))), &
      &  addr_of(c_loc(arr(1, n(2), 1))), &
      &  addr_of(c_loc(arr(1, 1, n(3))))], storage_size(arr))
  end function ext_acc_pass_3d_r4

  function ext_acc_pass_4d_r4(arr) result(ret)
    real(4), intent(in), target :: arr(:,:,:,:)
    type(ext_acc_arr_desc) :: ret
    integer :: n(4)

    ret = ext_acc_arr_desc(0_c_size_t, 0_c_size_t)
    if (size(arr) == 0) return
    n = shape(arr)
    ret = span_desc(addr_of(c_loc(arr(1, 1, 1, 1))), &
      & [addr_of(c_loc(arr(n(1), 1, 1, 1))), &
      &  addr_of(c_loc(arr(1, n(2), 1, 1))), &
      &  addr_of(c_loc(arr(1, 1, n(3), 1))), &
      &  addr_of(c_loc(arr(1, 1, 1, n(4))))], storage_size(arr))
  end function ext_acc_pass_4d_r4

  function ext_acc_pass_2d_r8(arr) result(ret)
    real(8), intent(in), target :: arr(:,:)
    type(ext_acc_arr_desc) :: ret
    integer :: n(2)

    ret = ext_acc_arr_desc(0_c_size_t, 0_c_size_t)
    if (size(arr) == 0) return
    n = shape(arr)
    ret = span_desc(addr_of(c_loc(arr(1, 1))), &
      & [addr_of(c_loc(arr(n(1), 1))), &
      &  addr_of(c_loc(arr(1, n(2))))], storage_size(arr))
  end function ext_acc_pass_2d_r8

  function ext_acc_pass_3d_r8(arr) result(ret)
    real(8), intent(in), target :: arr(:,:,:)
    type(ext_acc_arr_desc) :: ret
    integer :: n(3)

    ret = ext_acc_arr_desc(0_c_size_t, 0_c_size_t)
    if (size(arr) == 0) return
    n = shape(arr)
    ret = span_desc(addr_of(c_loc(arr(1, 1, 1))), &
      & [addr_of(c_loc(arr(n(1), 1, 1))), &
      &  addr_of(c_loc(arr(1, n(2), 1))), &
      &  addr_of(c_loc(arr(1, 1, n(3))))], storage_size(arr))
  end function ext_acc_pass_3d_r8

  function ext_acc_pass_4d_r8(arr) result(ret)
    real(8), intent(in), target :: arr(:,:,:,:)
    type(ext_acc_arr_desc) :: ret
    integer :: n(4)

    ret = ext_acc_arr_desc(0_c_size_t, 0_c_size_t)
    if (size(arr) == 0) return
    n = shape(arr)
    ret = span_desc(addr_of(c_loc(arr(1, 1, 1, 1))), &
      & [addr_of(c_loc(arr(n(1), 1, 1, 1))), &
      &  addr_of(c_loc(arr(1, n(2), 1, 1))), &
      &  addr_of(c_loc(arr(1, 1, n(3), 1))), &
      &  addr_of(c_loc(arr(1, 1, 1, n(4))))], storage_size(arr))
  end function ext_acc_pass_4d_r8

  !> Merge the byte ranges in in_ptrs into disjoint ranges.
  !>
  !> Non-empty ranges are visited in increasing start address. A range that
  !> starts at or before the end of the current merged range extends it;
  !> otherwise it opens a new merged range. Because of this ordering, a new
  !> range can only ever touch the most recently opened one. Empty ranges
  !> (sz <= 0, e.g. from zero-sized arrays) are skipped.
  !>
  !> out_ptrs(1:num_ranges) receives the merged ranges in ascending address
  !> order. out_ptrs must have at least size(in_ptrs) elements.
  subroutine get_common_pointers(in_ptrs, out_ptrs, num_ranges)
    type(ext_acc_arr_desc), intent(in) :: in_ptrs(:)
    type(common_pointer_descr), intent(out) :: out_ptrs(:)
    integer, intent(out) :: num_ranges

    integer(c_size_t) :: starts(size(in_ptrs))
    integer(c_size_t) :: merged_start(size(in_ptrs)), merged_end(size(in_ptrs))
    integer(c_size_t) :: range_start, range_end
    logical :: pending(size(in_ptrs))
    integer :: k, next

    starts = in_ptrs%ptr
    pending = in_ptrs%sz > 0
    num_ranges = 0

    do while (any(pending))
      next = minloc(starts, dim=1, mask=pending)
      pending(next) = .false.
      range_start = starts(next)
      range_end = range_start + in_ptrs(next)%sz

      ! Nested ifs: Fortran does not guarantee short-circuit evaluation, and
      ! merged_end(0) must not be referenced.
      if (num_ranges > 0) then
        if (range_start <= merged_end(num_ranges)) then
          merged_end(num_ranges) = max(merged_end(num_ranges), range_end)
          cycle
        end if
      end if

      num_ranges = num_ranges + 1
      merged_start(num_ranges) = range_start
      merged_end(num_ranges) = range_end
    end do

    do k = 1, num_ranges
      out_ptrs(k)%ptr = transfer(merged_start(k), out_ptrs(k)%ptr)
      out_ptrs(k)%sz = merged_end(k) - merged_start(k)
    end do
  end subroutine get_common_pointers

  !> Merge the ranges in ptrs and apply one OpenACC data operation (op_*) to
  !> each merged range. The calls are the synchronous acc_* variants.
  subroutine map_ranges(ptrs, op)
    type(ext_acc_arr_desc), intent(in) :: ptrs(:)
    integer, intent(in) :: op

    type(common_pointer_descr), allocatable :: ranges(:)
    integer(int32), pointer :: words(:)
    integer :: i, num_ranges, nbytes

    allocate(ranges(size(ptrs)))
    call get_common_pointers(ptrs, ranges, num_ranges)

    do i = 1, num_ranges
      ! The byte length is passed as a default integer. Refuse ranges that
      ! would not fit rather than passing a wrapped value.
      if (ranges(i)%sz > huge(nbytes)) then
        error stop "openacc_ext: merged range exceeds the default-integer length limit"
      end if
      nbytes = int(ranges(i)%sz)

      ! The extent drops any sub-word remainder. The byte length handed to
      ! the runtime is nbytes.
      call c_f_pointer(ranges(i)%ptr, words, [ranges(i)%sz / word_bytes])

      select case (op)
      case (op_create)
        call acc_create(words, nbytes)
      case (op_copyin)
        call acc_copyin(words, nbytes)
      case (op_copyout)
        call acc_copyout(words, nbytes)
      case (op_delete)
        call acc_delete(words, nbytes)
      end select
    end do
  end subroutine map_ranges

  !> Allocate device memory for the merged ranges of ptrs (acc_create).
  !> `stream` is accepted for future asynchronous use but is currently
  !> ignored; the operation is synchronous.
  subroutine ext_acc_create(ptrs, stream)
    type(ext_acc_arr_desc), intent(in) :: ptrs(:)
    integer(acc_handle_kind), intent(in), optional :: stream

    call map_ranges(ptrs, op_create)
  end subroutine ext_acc_create

  !> Allocate and copy host data to the device for the merged ranges of ptrs
  !> (acc_copyin). `stream` is currently ignored; the operation is synchronous.
  subroutine ext_acc_copyin(ptrs, stream)
    type(ext_acc_arr_desc), intent(in) :: ptrs(:)
    integer(acc_handle_kind), intent(in), optional :: stream

    call map_ranges(ptrs, op_copyin)
  end subroutine ext_acc_copyin

  !> Copy device data back to the host and release it for the merged ranges
  !> of ptrs (acc_copyout). `stream` is currently ignored; the operation is
  !> synchronous.
  subroutine ext_acc_copyout(ptrs, stream)
    type(ext_acc_arr_desc), intent(in) :: ptrs(:)
    integer(acc_handle_kind), intent(in), optional :: stream

    call map_ranges(ptrs, op_copyout)
  end subroutine ext_acc_copyout

  !> Release device memory for the merged ranges of ptrs without copying back
  !> (acc_delete). `stream` is currently ignored; the operation is synchronous.
  subroutine ext_acc_delete(ptrs, stream)
    type(ext_acc_arr_desc), intent(in) :: ptrs(:)
    integer(acc_handle_kind), intent(in), optional :: stream

    call map_ranges(ptrs, op_delete)
  end subroutine ext_acc_delete
end module openacc_ext
